import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle

# Import your updated custom/stochastic envs
import Continuous_CartPole
import Continuous_Pendulum
import continuous_mountain_car
import continuous_acrobot
import improved_hopper
import improved_ant
import improved_walker2d

# Import the fast high-dimensional environments
try:
    import fast_high_dim_envs
except ImportError as e:
    print(f"Warning: Could not import fast_high_dim_envs: {e}")
    print("Make sure fast_high_dim_envs.py is in the same directory")

from SnapshotENV import SnapshotEnv

####################################################################
# 1) Gym environment IDs to benchmark
####################################################################
env_names = [
    # Environments configured with a 1-D action space in the main block
    "Continuous-CartPole-v0",
    "StochasticPendulum-v0",
    "StochasticMountainCarContinuous-v0",
    "StochasticContinuousAcrobot-v0",
    # Environments configured with multi-dimensional action spaces
    "ImprovedHopper-v0",
    "ImprovedWalker2d-v0",
    "ImprovedAnt-v0",
]

####################################################################
# 2) Constructor kwargs (noise scales) passed to gym.make for each environment
####################################################################
# Per-environment keyword arguments forwarded to gym.make(envname, **kwargs).
# "alt" notes preserve the alternative values that were recorded next to a setting.
ENV_NOISE_CONFIG = {
    "Continuous-CartPole-v0": dict(
        action_noise_scale=0.05,
        dynamics_noise_scale=0.5,  # alt 0.01
        obs_noise_scale=0.0,
    ),
    "StochasticPendulum-v0": dict(
        action_noise_scale=0.02,
        dynamics_noise_scale=0.1,  # alt 0.01
        obs_noise_scale=0.01,
        # other constructor kwargs (e.g. g=9.8 for a different gravity) could be added here
    ),
    "StochasticMountainCarContinuous-v0": dict(
        action_noise_scale=0.05,  # alt 0.03
        dynamics_noise_scale=0.5,  # alt 0.01
        obs_noise_scale=0.0,
    ),
    "StochasticContinuousAcrobot-v0": dict(
        action_noise_scale=0.05,
        dynamics_noise_scale=0.7,  # alt 0.01
        obs_noise_scale=0.01,
    ),
    "ImprovedHopper-v0": dict(
        action_noise_scale=0.03,
        dynamics_noise_scale=0.02,
        obs_noise_scale=0.01,
    ),
    "ImprovedWalker2d-v0": dict(
        action_noise_scale=0.03,
        dynamics_noise_scale=0.02,
        obs_noise_scale=0.01,
    ),
    "ImprovedAnt-v0": dict(
        action_noise_scale=0.03,
        dynamics_noise_scale=0.02,
        obs_noise_scale=0.01,
    ),
}

####################################################################
# 3) Global config
####################################################################
# Upper bound on the length of random rollouts during planning.
MAX_UCT_DEPTH = 100
# Multiplier applied to the UCB exploration bonus.
SCALE_UCB = 10.0
# Score given to root / unvisited nodes so selection always tries them first.
MAX_VALUE = 1e100

# Evaluation settings: seeds per configuration, steps per test episode,
# and the discount applied to test-episode rewards.
num_seeds = 20
TEST_ITERATIONS = 150
discount = 0.99

# We'll define discretization logic inside a helper
def build_discretized_actions(envname, dim, rng=None):
    """Return a list of candidate actions, each a float32 array of shape ``(dim,)``.

    * ``dim == 1``: 10 evenly spaced values, in [-2, 2] for
      ``StochasticPendulum-v0`` and [-1, 1] for every other env.
    * ``2 <= dim <= 4``: a full grid over [-1, 1]^dim.
    * ``dim > 4``: ``min(1000, 3 ** dim)`` points drawn uniformly from [-1, 1]^dim.

    For the grid, ``round(10 ** (1 / dim))`` points per axis are requested and
    then clamped to [3, 10]. This yields 3 points per axis (9, 27 and 81 actions)
    for dim 2, 3 and 4.

    Args:
        envname: Gym environment ID; only used to pick the 1-D range.
        dim: Action dimensionality; must be >= 1.
        rng: Optional NumPy ``Generator`` / ``RandomState`` used for the
            high-dimensional sampling, for reproducible action sets. ``None``
            (the default) draws from the global ``np.random`` state.

    Returns:
        list of ``np.ndarray`` (float32, shape ``(dim,)``).

    Raises:
        ValueError: if ``dim < 1``.
    """
    if dim < 1:
        raise ValueError(f"dim must be >= 1, got {dim}")

    if dim == 1:
        n_actions = 10
        # Pendulum uses a wider range; all other 1-D envs here use [-1, 1].
        if envname == "StochasticPendulum-v0":
            low, high = -2.0, 2.0
        else:
            low, high = -1.0, 1.0
        values = np.linspace(low, high, n_actions)
        # Same element type/shape as the multi-dimensional branches below.
        return [np.array([v], dtype=np.float32) for v in values]

    if dim <= 4:
        # Full grid: 10 ** (1 / dim) points per axis would keep the total near
        # 10 actions, but the clamp to [3, 10] raises it to 3 per axis here.
        n_per_dim = int(round(10 ** (1.0 / dim)))
        n_per_dim = max(3, min(10, n_per_dim))
        axes = [np.linspace(-1.0, 1.0, n_per_dim) for _ in range(dim)]
        mesh = np.meshgrid(*axes, indexing="ij")
        points = np.stack([m.ravel() for m in mesh], axis=-1)
        return [point.astype(np.float32) for point in points]

    # High dimensions: a full grid would explode, so sample uniformly instead.
    n_samples = min(1000, int(3 ** dim))
    source = np.random if rng is None else rng
    samples = source.random((n_samples, dim)) * 2.0 - 1.0  # map [0, 1) -> [-1, 1)
    return [sample.astype(np.float32) for sample in samples]

####################################################################
# 4) UCT tree-node classes and the planning loop
####################################################################
class Node:
    """One state in the UCT search tree.

    A non-root node is created by stepping ``env`` from its parent's snapshot
    with ``action``. The resulting snapshot, observation, reward and done flag
    are stored on the node. ``value_sum`` / ``visit_count`` accumulate the
    undiscounted returns backed up through the node.
    """

    def __init__(self, parent, action, discretized_actions, env):
        self.parent = parent
        self.action = action  # shape=(dim,) or None if root
        # A list rather than a set: iteration order (and so which child ``max``
        # picks on ties) follows insertion order instead of object addresses,
        # which keeps seeded runs reproducible.
        self.children = []
        self.value_sum = 0.0
        self.visit_count = 0
        if parent is None:
            # Root: nothing is simulated here.
            self.snapshot = None
            self.obs = None
            self.immediate_reward = 0.0
            self.is_done = False
        else:
            if self.action is None:
                raise ValueError("Non-root node must have an action")
            action_array = np.asarray(self.action, dtype=np.float32)
            if parent.snapshot is None:
                raise ValueError("Parent node must have a valid snapshot")
            # Simulate one step from the parent's state to obtain this node's state.
            snap, obs, r, done, _ = env.get_result(parent.snapshot, action_array)
            self.snapshot = snap
            self.obs = obs
            self.immediate_reward = r
            self.is_done = done
        self.discretized_actions = discretized_actions

    def __repr__(self):
        return f"Node(action={self.action}, visits={self.visit_count}, value={self.value_sum})"

    def safe_delete(self):
        """Recursively drop the parent/children links of this subtree."""
        for child in self.children:
            child.safe_delete()
        del self.parent
        del self.children
        # Python's garbage collector will handle final removal,
        # but this is enough to break references.

    def is_root(self):
        return self.parent is None

    def is_leaf(self):
        return len(self.children) == 0

    def get_mean_value(self):
        """Average backed-up return, or 0.0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def ucb_score(self):
        """UCB1 score: mean value + SCALE_UCB * 2 * sqrt(ln(N_parent) / N)."""
        if self.is_root() or self.visit_count == 0:
            # Unvisited children are always tried first.
            return MAX_VALUE
        U = 2.0 * math.sqrt(math.log(self.parent.visit_count) / self.visit_count)
        return self.get_mean_value() + SCALE_UCB * U

    def selection(self):
        """Descend by highest UCB score until reaching a leaf; return that leaf."""
        node = self
        while not node.is_leaf():
            node = max(node.children, key=lambda c: c.ucb_score())
        return node

    def expand(self, env):
        """Add one child per discretized action (once), then descend by UCB.

        Terminal nodes are returned unchanged. Creating the children simulates
        one env step per action from this node's snapshot.
        """
        if self.is_done:
            return self  # no expansion from terminal
        if not self.children:
            for act in self.discretized_actions:
                # each act is shape=(dim,)
                self.children.append(Node(self, act, self.discretized_actions, env))
        return self.selection()

    def rollout(self, env, max_depth=None):
        """Return the undiscounted reward sum of a uniformly random rollout.

        The rollout starts from this node's snapshot and lasts at most
        ``max_depth`` steps. ``None`` means ``MAX_UCT_DEPTH``, looked up at
        call time. Terminal nodes return 0.0.
        """
        if self.is_done:
            return 0.0
        if max_depth is None:
            max_depth = MAX_UCT_DEPTH
        env.load_snapshot(self.snapshot)
        total = 0.0
        for _ in range(max_depth):
            act = random.choice(self.discretized_actions)
            _obs, r, done, _info = env.step(act)
            total += r
            if done:
                break
        return total

    def back_propagate(self, rollout_reward):
        """Back up a return from this node to the root.

        At every node on the path, the return seen from that node is its own
        ``immediate_reward`` plus the return of the child below it. The
        deepest node uses ``rollout_reward`` as that child return. Each node
        adds its return to ``value_sum`` and increments ``visit_count``.
        Passing the accumulated return upward, rather than only the rollout
        reward, keeps the rewards earned along the path in the ancestors'
        statistics.
        """
        node = self
        ret = rollout_reward
        while node is not None:
            ret = node.immediate_reward + ret
            node.value_sum += ret
            node.visit_count += 1
            node = node.parent

class Root(Node):
    """Tree root built from a known snapshot/observation; it never steps the env.

    ``immediate_reward`` is 0, so the root adds nothing of its own to the
    returns backed up through it.
    """
    def __init__(self, snapshot, obs, discretized_actions):
        super().__init__(parent=None, action=None, discretized_actions=discretized_actions, env=None)
        self.snapshot = snapshot
        self.obs = obs
        self.immediate_reward = 0.0
        self.is_done = False

    @staticmethod
    def to_root(node):
        """Promote ``node`` to a new ``Root``, reusing its subtree and statistics.

        The children collection is shared, not copied, and each child is
        re-parented onto the new root. Without this, back-propagation would
        stop at ``node`` (the children's old parent) instead of the new root,
        then continue through every discarded ancestor. That would also keep
        the whole old tree reachable.
        """
        root = Root(node.snapshot, node.obs, node.discretized_actions)
        # Copy over the subtree and its stats.
        root.children = node.children
        for child in root.children:
            child.parent = root
        root.value_sum = node.value_sum
        root.visit_count = node.visit_count
        root.is_done = node.is_done
        return root

def plan_mcts(root, n_iter, env, max_depth=None):
    """Grow the UCT tree under ``root`` in place for ``n_iter`` iterations.

    Each iteration selects a leaf by UCB.
    - A terminal leaf gets a zero rollout backed up.
    - Otherwise the leaf is expanded. A random rollout of at most
      ``max_depth`` steps is then run from the selected child and backed up.

    Args:
        root: Tree root (``Root``/``Node``) to grow.
        n_iter: Number of select/expand/rollout/backup iterations.
        env: Snapshot-capable environment used for expansion and rollouts.
        max_depth: Rollout horizon. ``None`` means the module-level
            ``MAX_UCT_DEPTH``.
    """
    if max_depth is None:
        max_depth = MAX_UCT_DEPTH
    for _ in range(n_iter):
        leaf = root.selection()
        if leaf.is_done:
            # Terminal state: nothing to expand or simulate.
            leaf.back_propagate(0.0)
            continue
        new_leaf = leaf.expand(env)
        rollout_return = new_leaf.rollout(env, max_depth=max_depth)
        new_leaf.back_propagate(rollout_return)


def setup_env_actions(envname, base_env):
    """Return ``(dim, max_depth, discretized_actions)`` for ``envname``.

    - The 1-D envs get 10 evenly spaced float32 actions. CartPole takes its
      bounds from ``base_env.min_action`` / ``base_env.max_action``; the
      others use fixed ranges.
    - The multi-dimensional ``Improved*`` envs delegate to
      ``build_discretized_actions``.
    - Unknown IDs fall back to a single zero action.

    ``max_depth`` is the rollout horizon passed to ``plan_mcts``.
    """
    n_actions = 10
    fixed_1d_bounds = {
        "StochasticPendulum-v0": (-2.0, 2.0),
        "StochasticMountainCarContinuous-v0": (-1.0, 1.0),
        "StochasticContinuousAcrobot-v0": (-1.0, 1.0),
    }
    multi_dim = {
        "ImprovedHopper-v0": 3,
        "ImprovedWalker2d-v0": 6,
        "ImprovedAnt-v0": 8,
    }

    if envname == "Continuous-CartPole-v0":
        bounds = (base_env.min_action, base_env.max_action)
    else:
        bounds = fixed_1d_bounds.get(envname)

    if bounds is not None:
        grid = np.linspace(bounds[0], bounds[1], n_actions)
        return 1, 50, [np.array([a], dtype=np.float32) for a in grid]
    if envname in multi_dim:
        dim = multi_dim[envname]
        return dim, 100, build_discretized_actions(envname, dim)
    # fallback
    return 1, 50, [np.array([0.0], dtype=np.float32)]


def run_test_episode(root, planning_env, start_snapshot, discretized_actions,
                     dim, n_iter, max_depth):
    """Evaluate the planner for one episode and return its discounted return.

    A fresh environment is unpickled from ``start_snapshot``. Each of at most
    ``TEST_ITERATIONS`` steps does the following:
    - execute the root child with the highest mean value (a zero action if
      the root has no children);
    - discount the reward by ``discount``;
    - re-root the tree;
    - run ``plan_mcts`` again with ``n_iter`` iterations and horizon
      ``max_depth``.

    The test environment is closed however the episode ends.
    """
    test_env = pickle.loads(start_snapshot)
    total_reward = 0.0
    current_discount = 1.0
    try:
        for _ in range(TEST_ITERATIONS):
            # pick child with best mean value
            if root.children:
                best_child = max(root.children, key=lambda c: c.get_mean_value())
                best_action = np.asarray(best_child.action, dtype=np.float32)
            else:
                # Nothing expanded (e.g. a terminal root): fall back to a zero action.
                best_child = None
                best_action = np.zeros(dim, dtype=np.float32)

            s, r, done, _ = test_env.step(best_action)
            total_reward += r * current_discount
            current_discount *= discount
            if done:
                break

            # Drop the chosen child's siblings so their subtrees can be freed;
            # clear() avoids an O(n) remove() per sibling.
            for child in root.children:
                if child is not best_child:
                    child.safe_delete()
            root.children.clear()

            if best_child is None:
                # No children explored yet: build a fresh root from the state
                # test_env actually reached.
                planning_env.load_snapshot(pickle.dumps(test_env))
                root = Root(
                    snapshot=planning_env.get_snapshot(),
                    obs=s,
                    discretized_actions=discretized_actions,
                )
            else:
                # NOTE(review): best_child's snapshot was produced by a separate
                # planning-time step. With noisy dynamics it can differ from the
                # state test_env just reached, so re-planning continues from the
                # simulated state. Confirm this open-loop tree reuse is intended.
                root = Root.to_root(best_child)

            plan_mcts(root, n_iter=n_iter, env=planning_env, max_depth=max_depth)
    finally:
        test_env.close()
    return total_reward


####################################################################
# 5) Main experiment: UCT planning for every env / budget / seed
####################################################################
if __name__ == "__main__":
    results_filename = "uct_results_high_dim.txt"

    # Iteration budgets grow geometrically: 3 * base**i for i = 0..15,
    # i.e. from 3 up to about 3000. Only the first six are run.
    base = 1000 ** (1.0 / 15.0)
    samples = [int(3 * (base ** i)) for i in range(16)]
    # samples = [1000, 2000,3000,4000,5000,6000]
    samples_to_use = samples[0:6]

    # ``with`` guarantees the results file is closed even if a run raises.
    with open(results_filename, "a", encoding="utf-8") as f_out:
        for envname in env_names:
            print(f"\n{'='*60}")
            print(f"Starting experiments for {envname}")
            print(f"{'='*60}")

            # A) Build the environment with its configured noise kwargs.
            stoch_kwargs = ENV_NOISE_CONFIG.get(envname, {})
            base_env = gym.make(envname, **stoch_kwargs).env

            # B) Action dimension, rollout horizon and candidate actions.
            dim, max_depth, discretized_actions = setup_env_actions(envname, base_env)

            print(f"\nEnvironment: {envname}")
            print(f"Action dimension: {dim}")
            print(f"Number of discretized actions: {len(discretized_actions)}")
            print(f"Max depth: {max_depth}")

            # Wrap the same env in SnapshotEnv for planning.
            planning_env = SnapshotEnv(base_env)

            root_obs_ori = planning_env.reset()
            root_snapshot_ori = planning_env.get_snapshot()
            if root_snapshot_ori is None:
                raise ValueError(f"Failed to get initial snapshot for {envname}")

            for ITERATIONS in samples_to_use:
                print(f"\nRunning {ITERATIONS} iterations for {envname}")
                seed_returns = []

                for seed_i in range(num_seeds):
                    if seed_i % 5 == 0:
                        print(f"  Seed {seed_i}/{num_seeds}")

                    random.seed(seed_i)
                    np.random.seed(seed_i)

                    root_obs = copy.copy(root_obs_ori)
                    root_snapshot = copy.copy(root_snapshot_ori)
                    root = Root(
                        snapshot=root_snapshot,
                        obs=root_obs,
                        discretized_actions=discretized_actions,
                    )
                    plan_mcts(root, n_iter=ITERATIONS, env=planning_env,
                              max_depth=max_depth)

                    seed_returns.append(run_test_episode(
                        root, planning_env, root_snapshot, discretized_actions,
                        dim, ITERATIONS, max_depth,
                    ))

                mean_return = statistics.mean(seed_returns)
                std_return = statistics.pstdev(seed_returns)
                interval = 2.0 * std_return

                msg = (f"Env={envname}, ITER={ITERATIONS}: "
                       f"Mean={mean_return:.3f} ± {interval:.3f} "
                       f"(over {num_seeds} seeds)")
                print(msg)
                f_out.write(msg + "\n")
                f_out.flush()  # Ensure immediate writing

    print("Done! Results saved to", results_filename)
